In [ ]:
import importlib
import seaborn as sns
import matplotlib.pyplot as plt
import draft.models.MobileNet.runner_scripts.trainer as trainer
import draft.models.MobileNet.classifier as classifier
import draft.models.MobileNet.data_loader as data_loader
import draft.models.MobileNet.metrics as metrics
import os
import Notebooks.utils.utils as utils
import Notebooks.utils.error_analysis as error_analysis
import pandas as pd
from IPython.display import display, HTML
In [262]:
utils.fix_cwd()
sns.set_theme(style="darkgrid", palette="pastel")
plt.style.use("fivethirtyeight")

Performance Analysis¶

We'll start by analyzing our best performing model after initial tuning. After ~100-200 runs to select the:

  • best performing scheduler {explain which}
  • optimal set of transformations/augmentations {explain which}

we've arrived at this configuration by trying only to maximize the high-level model performance metrics:

  • total weighted loss (combined from the normalized gender and age prediction losses)
  • gender prediction accuracy
  • MAE for age predictions
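As a rough illustration, the combined objective can be sketched as a weighted sum of the per-task losses after normalizing each to a comparable scale. The weights and normalizers below are hypothetical, not the values used in our runs:

```python
# Hedged sketch of the combined tuning objective; all constants are
# illustrative assumptions, not the project's actual values.
def combined_loss(gender_loss, age_loss,
                  gender_norm=0.7, age_norm=8.0,   # rough per-task loss scales (assumed)
                  w_gender=0.5, w_age=0.5):
    """Normalize each task loss to a comparable scale, then mix them."""
    return w_gender * (gender_loss / gender_norm) + w_age * (age_loss / age_norm)

print(combined_loss(0.35, 4.0))  # → 0.5
```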
In [160]:
 
In [164]:
%%html
<iframe src="https://wandb.ai/qqwy/ag_classifier_main/reports/Best-Iteration-1-model-graceful-hill-257---Vmlldzo4ODIwODg5" style="border:none;height:1024px;width:100%"></iframe>
In [167]:
data = {
    'Parameter': [
        'model_type', 'lr_scheduler', 'anneal_strategy', 'base_lr', 'batch_size',
        'div_factor', 'dropout', 'final_div_factor', 'freeze_epochs', 'l1_lambda',
        'max_lr', 'num_epochs', 'override_cycle_epoch_count', 'weight_decay',
        'pct_start', 'train_path', 'val_path'
    ],
    'Value': [
        'mobilenet_v3_small', 'one_cycle', 'cos', 0.0068893981577029285, 256,
        24, 0.1, 2873, 0, 0.0001, 0.012321315111072404, 18, 15,
        0.00019323262043373016, 0.36685557351085574, 'dataset/train_8_folds_first',
        'dataset/test_2_folds_last'
    ]
}

pd.DataFrame(data)
Out[167]:
Parameter Value
0 model_type mobilenet_v3_small
1 lr_scheduler one_cycle
2 anneal_strategy cos
3 base_lr 0.006889
4 batch_size 256
5 div_factor 24
6 dropout 0.1
7 final_div_factor 2873
8 freeze_epochs 0
9 l1_lambda 0.0001
10 max_lr 0.012321
11 num_epochs 18
12 override_cycle_epoch_count 15
13 weight_decay 0.000193
14 pct_start 0.366856
15 train_path dataset/train_8_folds_first
16 val_path dataset/test_2_folds_last
Main Observations¶
  • using one_cycle as our LR scheduler has allowed us to reach convergence in only ~15 epochs while providing significantly better performance than reduce_on_plateau or step_lr achieved even after 30-40 epochs.

  • freeze_epochs

  • Model was fine-tuned from pretrained weights (IMAGENET1K_V1). We've found that training MobileNet from scratch (using randomized initial weights) provides comparable or only slightly inferior performance on the UTK dataset. We've still chosen the pretrained weights because:

    • the model performs a bit better (0.015 higher accuracy, ~0.2 lower MAE)
    • since the pretrained model has seen a greater variety of images in different conditions, it should perform better (or at least no worse) on images of faces in real-world conditions.
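For reference, the one_cycle schedule from the first observation can be reproduced with PyTorch's built-in OneCycleLR, plugging in the tuned values from the parameter table above (steps_per_epoch here is a made-up stand-in for the real loader length):

```python
import torch

# Sketch of the one_cycle LR schedule using the tuned hyper-parameters from
# the table above; the model, optimizer, and step counts are placeholders.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.01)

steps_per_epoch, epochs = 10, 15  # assumed values for illustration
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt,
    max_lr=0.012321,
    total_steps=steps_per_epoch * epochs,
    pct_start=0.366856,        # warm-up fraction of the cycle
    anneal_strategy='cos',
    div_factor=24,             # initial_lr = max_lr / 24
    final_div_factor=2873,     # min_lr = initial_lr / 2873
)

lrs = []
for _ in range(steps_per_epoch * epochs):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]['lr'])

# LR ramps up to ~max_lr over the first ~37% of steps,
# then cosine-anneals down to a tiny final value.
```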
In [5]:
BASE_MODEL_NAME = "final_prod_z5yxudkl_graceful-hill-257_19_0.9310.pth"
OVERSAMPLE_AUG_MODEL_NAME = "full_aug_small_production_v1.pth"
NOT_PRETRAINED_MODEL_NAME = "NO_WEIGHTS_full_dynamic_aug_tune_18_cycle+3_sage-planet-309_20_0.9259.pth"
In [6]:
test_config = {
    'ds_path': 'dataset/test_2_folds_last',
    'batch_size': 512,
}
In [179]:
 
In [ ]:
base_model = trainer.load_model(BASE_MODEL_NAME)
base_model.eval()
improved_model = trainer.load_model(OVERSAMPLE_AUG_MODEL_NAME)
improved_model.eval()
In [ ]:
data_module_base = data_loader.create_dataloaders(test_config, mode='test')
data_module_base.setup('test')
predictions_base = classifier.predict_with_model(base_model, data_module_base)
In [ ]:
data_module_improved = data_loader.create_dataloaders(test_config, mode='test')
data_module_improved.setup('test')
predictions_improved = classifier.predict_with_model(improved_model, data_module_improved)
In [ ]:
importlib.reload(error_analysis)

image_data_path = 'dataset/image_entropy_summary.csv'
image_data = pd.read_csv(image_data_path)

merged_data_base = error_analysis.sync_predictions_with_image_data(predictions_base, image_data)
merged_data_improved = error_analysis.sync_predictions_with_image_data(predictions_improved, image_data)


image_quality_metrics_base = error_analysis.evaluate_by_image_quality(merged_data_base)
image_quality_metrics_improved = error_analysis.evaluate_by_image_quality(merged_data_improved)
In [ ]:
 

Performance¶

In [172]:
importlib.reload(metrics)
evaluation_results_improved = metrics.evaluate_predictions(predictions_improved)

evaluation_results_base = metrics.evaluate_predictions(predictions_base)
evaluation_results_base["gender_metrics"].round(3)
Out[172]:
Female Male Overall
Support 2353.000 2387.000 4740.000
Accuracy 0.931 0.931 0.931
Precision 0.924 0.938 0.931
Recall 0.938 0.924 0.931
F1-score 0.931 0.931 0.931
AUC-ROC NaN NaN 0.981
PR-AUC NaN NaN 0.978
Log Loss NaN NaN 0.179
Brier Score NaN NaN NaN
In [173]:
evaluation_results_base["age_metrics"]
Out[173]:
Value
MAE 5.105901
MSE 54.144762
RMSE 7.358312
R-squared 0.862191
MAPE 25.161557

We've been able to achieve an accuracy of ~93% for gender predictions and an age MAE (Mean Absolute Error) of around 5.1 years.
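These two headline numbers reduce to simple computations over the prediction frame. A minimal sketch, assuming the same columns used elsewhere in this notebook (true_gender, gender_pred as a probability, true_age, age_pred) and toy values:

```python
import pandas as pd

# Toy predictions frame; column names follow the rest of this notebook.
preds = pd.DataFrame({
    "true_gender": [0, 1, 1, 0],
    "gender_pred": [0.1, 0.9, 0.4, 0.2],   # predicted P(gender == 1)
    "true_age":    [25, 60, 8, 40],
    "age_pred":    [27.0, 55.5, 10.0, 38.0],
})

# Gender accuracy: threshold the probability at 0.5 and compare to labels.
gender_acc = ((preds["gender_pred"] > 0.5).astype(int) == preds["true_gender"]).mean()
# Age MAE: mean absolute deviation of predicted from true age.
age_mae = (preds["age_pred"] - preds["true_age"]).abs().mean()
print(gender_acc, age_mae)  # → 0.75 2.625
```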

In [181]:
utils.get_baselines_table()
Out[181]:
Model Age Estimation (MAE) Gender Classification (Accuracy)
0 XGBoost (+feat. extraction) 5.89 93.80
1 SVC(..) 5.49 94.64
2 VGG_f 4.86 93.42
3 ResNet50_f 4.65 94.64
4 SENet50_f 4.58 94.90

(*https://arxiv.org/pdf/2110.12633)

While our model still lags behind these baselines, that seems like a reasonably good result for using such a small model directly (i.e. no ensembles/meta-models).

(Specific dataset split and preprocessing)

In [ ]:
 
In [183]:
importlib.reload(utils)
utils.model_desc_table()
Out[183]:
VGG16 ResNet50 MobileNetV3-Small
Metric
Parameter Count ~138 million ~25.6 million ~2.5 million
Model Size (PyTorch, FP32) ~528 MB ~98 MB ~10 MB
Inference Speed (relative) 1x (baseline) ~2.5x faster ~10x faster
FLOPs ~15.5 billion ~4.1 billion ~56 million
Approx. Memory Usage (inference) 1x ~0.6x ~0.15x

Overall this is not necessarily exceptional: the UTK Face dataset is relatively small and specific compared to general image classification tasks (which can effectively level the playing field for smaller models), and several other studies/benchmarks show MobileNet variants performing competitively with larger models on simple tasks like this (while performing significantly worse at more complex tasks like emotion detection or face recognition):

e.g. according to Savchenko, A. V. (2024). arXiv. https://ar5iv.labs.arxiv.org/html/2103.17107, MobileNet without any fine-tuning on the UTKFace dataset (i.e. the full UTKFace was used for testing) actually outperformed VGG-16 & ResNet-50.

Gender Classification¶

In [274]:
importlib.reload(error_analysis)

error_analysis.confusion_matrix_plot_v2(merged_data_base, "true_gender", "gender_pred", class_labels=["Male", "Female"])
Out[274]:
<Axes: title={'center': 'Confusion Matrix with Percentage Accuracy'}, xlabel='Predicted label', ylabel='True label'>
No description has been provided for this image
In [274]:
 
In [235]:
 
In [274]:
 
In [236]:
 

Accuracy of Gender Prediction by Age Group¶

In [277]:
evaluation_results_base['gender_accuracy_by_age']
Out[277]:
Total Correct Accuracy
Age_Group
0-4 444 307 0.6914
4-14 261 215 0.8238
14-24 636 604 0.9497
24-30 1228 1187 0.9666
30-40 865 837 0.9676
40-50 399 393 0.9850
50-60 420 409 0.9738
60-70 229 218 0.9520
70-80 156 149 0.9551
80+ 102 94 0.9216

We can see that gender prediction accuracy is reasonably high across all ranges except young children. Realistically, it's unlikely we can do much about that: facial features of babies tend to be very different from adults'. It might be worth investigating a separate model for them, but it's unlikely it would achieve very high performance either.
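A per-age-bin accuracy table like the one above can be built with pd.cut plus a groupby. A sketch with toy data; the bin edges mirror the table above, but the helper itself is illustrative:

```python
import pandas as pd

# Toy data; column names follow the rest of this notebook.
df = pd.DataFrame({
    "true_age":    [2, 3, 20, 22, 35, 70],
    "true_gender": [0, 1, 1, 0, 1, 0],
    "gender_pred": [0.8, 0.3, 0.9, 0.1, 0.7, 0.2],
})
bins = [0, 4, 14, 24, 30, 40, 50, 60, 70, 80, 200]
labels = ["0-4", "4-14", "14-24", "24-30", "30-40",
          "40-50", "50-60", "60-70", "70-80", "80+"]
df["age_group"] = pd.cut(df["true_age"], bins=bins, labels=labels, right=False)
df["correct"] = (df["gender_pred"] > 0.5).astype(int) == df["true_gender"]

# Aggregate per bin: sample count, correct count, and accuracy.
table = df.groupby("age_group", observed=True)["correct"].agg(Total="count", Correct="sum")
table["Accuracy"] = table["Correct"] / table["Total"]
print(table)
```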

Summary of Age Prediction¶

In [21]:
evaluation_results_base['age_statistics']
Out[21]:
True Age Predicted Age
Mean 33.308439 32.147823
Median 29.000000 28.514690
Min 1.000000 -2.139822
Max 116.000000 95.214233

Age Prediction by Age Group¶

In [25]:
importlib.reload(metrics)
evaluation_results_base['performance_by_age_bin']
Out[25]:
Age_Group Support Age_MAE Age_MSE Age_RMSE Age_R-squared Age_MAPE
0 0-4 444 1.588580 11.325658 3.365361 -9.241579 99.745904
1 4-14 261 4.011655 34.033093 5.833789 -3.743251 46.700869
2 14-24 636 4.171022 32.965802 5.741585 -2.937213 21.156784
3 24-30 1228 3.720786 30.006521 5.477821 -10.167695 13.674633
4 30-40 865 6.270144 63.924114 7.995256 -7.162335 17.644973
5 40-50 399 7.749943 96.742555 9.835779 -10.194667 16.942367
6 50-60 420 7.311122 91.486462 9.564856 -11.248783 13.271226
7 60-70 229 6.725516 80.393407 8.966237 -8.236708 10.369088
8 70-80 156 7.617475 105.892985 10.290432 -11.530508 10.082188
9 80+ 102 8.947648 173.258202 13.162758 -3.118748 9.777900

This table shows one of the flaws of using MAE as our target metric: it downplays inaccurate predictions for children and potentially exaggerates them as the subject's age increases.

i.e. misclassifying a newborn as a 5-year-old child (or the other way around) is a much bigger error than making the same 5-year mistake when the subject is over 70.

MAPE (Mean Absolute Percentage Error) would potentially be a better metric; however, it can be (and clearly is) problematic for very young ages (near zero), as it leads to extremely large or undefined percentages.
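A quick worked example of that point: the same 5-year absolute error produces wildly different MAPE values depending on the true age:

```python
import numpy as np

def mape(y_true, y_pred):
    """Mean Absolute Percentage Error, in percent."""
    y_true, y_pred = np.asarray(y_true, float), np.asarray(y_pred, float)
    return np.mean(np.abs((y_pred - y_true) / y_true)) * 100

# Identical 5-year absolute error, very different relative error:
print(mape([1], [6]))    # newborn predicted as 6 → 500%
print(mape([70], [75]))  # same 5-year miss at 70 → ~7.1%
```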

In [276]:
def process_age_groups(df, true_col, pred_col):
    # Sort bins by their numeric lower bound: a plain lexicographic sort would
    # put "14-24" before "4-14".
    age_groups = sorted(df[true_col].unique(), key=lambda g: float(g.split('-')[0]))
    lower_bounds = [float(g.split('-')[0]) for g in age_groups]

    # Integer index of the true bin...
    df['true_group_index'] = pd.Categorical(df[true_col], categories=age_groups).codes.astype(int)
    # ...and of the predicted bin (clamp negative age predictions to just above 0).
    df['pred_group_index'] = pd.cut(
        df[pred_col].clip(lower=0.01),
        bins=lower_bounds + [float('inf')],
        labels=False,
    ).astype(int)

    return df, age_groups

df, class_labels = process_age_groups(merged_data_base, 'age_group', 'age_pred')

importlib.reload(error_analysis)
error_analysis.confusion_matrix_plot_v2(df, "true_group_index", "pred_group_index", class_labels=class_labels)
Out[276]:
<Axes: title={'center': 'Confusion Matrix with Percentage Accuracy'}, xlabel='Predicted label', ylabel='True label'>
No description has been provided for this image
In [275]:
 
In [273]:
importlib.reload(error_analysis)
error_analysis.evaluate_age_prediction(merged_data_base["true_age"], merged_data_base["age_pred"], bins=metrics.DEFAULT_AGE_BINS)
No description has been provided for this image
In [ ]:
# Usage
# error_analysis.evaluate_age_prediction(predictions['true_age'], predictions['age_pred'])  
In [ ]:
# TODO: split
# these
# into
# separate
# plots...

LIME¶

In [27]:
from skimage.color import gray2rgb, rgb2gray  # since the code wants color images
In [29]:
from PIL import Image

img = Image.open("dataset/full/40_0_0_20170117151450653.jpg.chip.jpg").convert("RGB")
In [30]:
import lime
import lime.lime_image as lime_image
In [33]:
# from lime.wrappers.scikit_image import SegmentationAlgorithm
# 
# explainer = lime_image.LimeImageExplainer(verbose=False)
# segmenter = SegmentationAlgorithm('slic', n_segments=100, compactness=1, sigma=1)
In [33]:
 
In [33]:
 

Solving Age Balancing¶

In [34]:
import torch
import draft.models.MobileNet.data_defs as data_defs
import draft.models.MobileNet.metrics as metrics

importlib.reload(data_defs)

from lime import lime_image
from skimage.segmentation import mark_boundaries
import numpy as np
In [ ]:
 
In [35]:
importlib.reload(error_analysis)
Out[35]:
<module 'Notebooks.utils.error_analysis' from '/mnt/v/projects/DL_s3/Notebooks/utils/error_analysis.py'>
In [36]:
# Usage
image_files = [
    "dataset/full/3_1_0_20170109193055962.jpg.chip.jpg",
    # "dataset/full/8_1_0_20170109202254541.jpg.chip.jpg",
    "dataset/full/15_0_0_20170104012346994.jpg.chip.jpg",
    "dataset/full/17_1_0_20170109214008165.jpg.chip.jpg",
    "dataset/full/31_1_4_20170117203039631.jpg.chip.jpg",
    "dataset/full/40_0_0_20170117151450653.jpg.chip.jpg",
    "dataset/full/50_0_0_20170111181750459.jpg.chip.jpg",
    # "dataset/full/57_1_0_20170110131940730.jpg.chip.jpg",
    # "dataset/full/68_1_0_20170110183125200.jpg.chip.jpg",
    "dataset/full/79_0_0_20170111222432817.jpg.chip.jpg",
    "dataset/full/110_0_0_20170112213500903.jpg.chip.jpg",
]

test_set = error_analysis.process_images(base_model, image_files)
In [126]:
importlib.reload(error_analysis)
error_analysis.display_grid(test_set, scale=0.35)
No description has been provided for this image
Figure size: 840x2240 px
In [127]:
# error_analysis.display_grid(test_set)
importlib.reload(error_analysis)
misclassified_files = error_analysis.get_misclassified_from_predictions(predictions_base, data_module_base, test_config,
                                                                        n=8)
In [132]:
 
In [139]:
results_combined = error_analysis.process_images(base_model, misclassified_files.combined[:5])
results_age = error_analysis.process_images(
    base_model, [p for p in misclassified_files.age if p not in misclassified_files.combined])
results_gender = error_analysis.process_images(
    base_model, [p for p in misclassified_files.gender if p not in misclassified_files.combined])
In [132]:
misclassified_files.combined[:5]
Out[132]:
['dataset/test_2_folds_last/111_1_0_20170120134646399.jpg.chip.jpg',
 'dataset/test_2_folds_last/1_1_0_20170109194452834.jpg.chip.jpg',
 'dataset/test_2_folds_last/9_0_0_20170110225030430.jpg.chip.jpg',
 'dataset/test_2_folds_last/8_0_1_20170114025855492.jpg.chip.jpg',
 'dataset/test_2_folds_last/41_1_1_20170117021604893.jpg.chip.jpg']

Most Misclassified Images (both gender/age)¶

In [130]:
importlib.reload(error_analysis)
error_analysis.display_grid(results_combined)
No description has been provided for this image
Figure size: 840x1400 px
In [144]:
 
In [142]:
 
In [142]:
error_analysis.display_grid(results_age)
No description has been provided for this image
Figure size: 840x1400 px

Misclassified Gender¶

Looking at gender specifically, it's actually likely that our model performs better than the summarized results imply.

The images above showcase where our model was least accurate, and we can see that all except one are likely cases of data being mislabeled in the original dataset (OR labeled accurately based on those individuals' self-identification).

In [144]:
 
In [144]:
error_analysis.display_grid(results_gender)
No description has been provided for this image
Figure size: 840x1960 px
In [ ]:
 
In [ ]:
 
In [ ]:
 
In [111]:
 

We can see three main issues:

  1. Some images are poor quality or strongly cropped. We may be able to solve this by using heuristics in preprocessing to exclude these samples from the training and test sets.

  2. We can see certain patterns related to race and age. The model has issues classifying faces of people who are non-white, possibly due to different facial features or skin color (although the grayscale transform should partially mitigate the latter). It also struggles with either very old people or children/babies, possibly because of small sample sizes and relatively more "androgynous" facial features in those groups. We'll attempt to fix this using augmentation in combination with oversampling (i.e. we'll use transforms to create additional samples for underrepresented age bins; additionally, we'll use some of the color analysis from the EDA to oversample images of underrepresented skin colors).

  3. Many samples are potentially mislabeled. It's possible that some samples are of people who self-identify as male/female while retaining facial features, hairstyles etc. of the opposite gender. Or they are just mislabeled. In either case, this part would be the hardest to solve.

Filtering Out "Invalid" Samples¶

We'll use a mix of metrics to try to determine which images are very poor quality, lack enough detail for proper classification, etc.:

BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator):

A no-reference image quality assessment method. Uses scene statistics of locally normalized luminance coefficients to quantify possible losses of "naturalness" in the image due to distortions. Operates in the spatial domain.

Laplacian Variance:

A measure of image sharpness/blurriness. Uses the Laplacian operator to compute the second derivative of the image. Measures the variance of the Laplacian-filtered image.
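A dependency-free sketch of that idea (our actual pipeline likely computes this differently, e.g. via OpenCV's cv2.Laplacian):

```python
import numpy as np

def laplacian_variance(img):
    """Sharpness score: variance of the discrete Laplacian of the image."""
    img = np.asarray(img, float)
    # 4-neighbour discrete Laplacian on the interior pixels.
    lap = (img[:-2, 1:-1] + img[2:, 1:-1] + img[1:-1, :-2] + img[1:-1, 2:]
           - 4 * img[1:-1, 1:-1])
    return lap.var()

rng = np.random.default_rng(0)
sharp = rng.random((64, 64))       # noisy image: lots of high-frequency detail
flat = np.full((64, 64), 0.5)      # uniform image: no detail at all
print(laplacian_variance(sharp) > laplacian_variance(flat))  # → True
```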

FFT-based Blur Detection:

Uses Fast Fourier Transform to analyze the frequency components of an image. Applies a high-pass filter in the frequency domain and measures the remaining energy.
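A minimal numpy sketch of this score, with a hypothetical high-pass radius:

```python
import numpy as np

def fft_blur_score(img, radius=8):
    """High-frequency energy left after suppressing the low frequencies."""
    f = np.fft.fftshift(np.fft.fft2(np.asarray(img, float)))
    h, w = f.shape
    cy, cx = h // 2, w // 2
    # High-pass filter: zero out a square of low frequencies around DC.
    f[cy - radius:cy + radius, cx - radius:cx + radius] = 0
    return np.mean(np.abs(f))

rng = np.random.default_rng(1)
noisy = rng.random((64, 64))        # detailed image: high-frequency energy remains
smooth = np.full((64, 64), 0.5)     # featureless image: almost nothing remains
print(fft_blur_score(noisy) > fft_blur_score(smooth))  # → True
```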

See the Data Analysis notebook for more details.

BRISQUE + Laplacian Variance¶

In [ ]:
worst_quality_images = error_analysis.get_worst_quality_images(merged_data_base)
worst_quality_images
In [ ]:
worst_quality_images_paths = format_image_paths(worst_quality_images["image_path"])
worst_quality_images_paths
In [ ]:
image_quality_metrics_base["brisque_score"].sort_values(by=["Bin"], ascending=False)
In [ ]:
merged_data_base[merged_data_base["brisque_score"] > 60]
In [ ]:
# merged_data[merged_data["brisque_score"] > 60]["image_path"]


low_brisque_images = [
    "dataset/full/17_0_0_20170117091447979.jpg.chip.jpg",
    "dataset/full/18_0_1_20170113175821404.jpg.chip.jpg",
    # "dataset/full/18_1_0_20170109212756182.jpg.chip.jpg",
    # "dataset/full/1_0_2_20161219162726231.jpg.chip.jpg",
    # "dataset/full/1_1_4_20170109194502921.jpg.chip.jpg",
    # "dataset/full/24_0_1_20170116024612640.jpg.chip.jpg",
    # "dataset/full/24_1_0_20170116222405565.jpg.chip.jpg",
    # "dataset/full/24_1_2_20170116173403123.jpg.chip.jpg"
]

importlib.reload(error_analysis)
results_worst_quality_images = error_analysis.process_images(base_model, low_brisque_images)
In [ ]:
error_analysis.display_grid(results_worst_quality_images)
In [ ]:

One obvious major shortcoming of this approach is that we're excluding a significant proportion of samples basically just because our model performs poorly on them.

While {TODO}

A production pipeline might be:

  1. Check if image is valid using heuristics (e.g. telling the user to position the camera better etc.)
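A hypothetical shape for that first step, with illustrative (untuned) thresholds:

```python
# Hypothetical quality gate for the pipeline step above; both thresholds are
# illustrative assumptions, not tuned values from this project.
BRISQUE_MAX = 60        # scores above this: heavily distorted
LAPLACIAN_MIN = 50.0    # variance below this: too blurry

def image_is_usable(brisque_score, laplacian_variance):
    """Return (ok, reason) so a UI can tell the user what to fix."""
    if brisque_score > BRISQUE_MAX:
        return False, "image too distorted, retake the photo"
    if laplacian_variance < LAPLACIAN_MIN:
        return False, "image too blurry, hold the camera steady"
    return True, "ok"

print(image_is_usable(30, 400.0))   # → (True, 'ok')
print(image_is_usable(75, 400.0))   # → (False, 'image too distorted, retake the photo')
```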
In [ ]:
 
In [ ]:
 
In [ ]:
 

Augmentation Based Oversampling¶

We'll use augmentation/transforms combined with oversampling to increase the number of samples in underrepresented classes. This approach:

  • allows us to preserve original data characteristics while introducing variability

Potential issues:

  • Risk of overfitting to augmented versions of underrepresented samples
  • Possibility of introducing unintended biases if augmentation isn't carefully balanced
  • May not fully address underlying dataset biases
  • Requires careful monitoring to ensure improved performance across all age groups
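The oversampling half of this can be sketched as computing a per-bin duplication factor inversely proportional to bin size (each duplicate then gets an independently randomized augmentation). The bin counts below are toy values:

```python
import math

# Toy per-age-bin sample counts; the bin labels are illustrative.
counts = {"0-18": 900, "18-40": 2400, "40-60": 700, "60+": 300}
target = max(counts.values())

# How many times each bin's samples should appear so every bin
# roughly matches the largest one (rounded up).
factors = {b: math.ceil(target / c) for b, c in counts.items()}
print(factors)  # → {'0-18': 3, '18-40': 1, '40-60': 4, '60+': 8}
```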
In [ ]:
 
In [ ]:
 

Comparing Both Models¶

Let's look at samples that were misclassified by the initial model but are now correct with the new model:

In [ ]:
merged_data_base
In [112]:
base_data_wrong_pred_df_good_on_improved = merged_data_base[
    ((merged_data_base['gender_pred'] > 0.5) & (merged_data_base['true_gender'] == 0)) |
    ((merged_data_base['gender_pred'] <= 0.5) & (merged_data_base['true_gender'] == 1))
    ]

base_data_wrong_pred_df_good_on_improved = pd.merge(
    base_data_wrong_pred_df_good_on_improved,
    merged_data_improved[['image_path', 'true_gender', 'gender_pred']],
    on='image_path',
    how='left'
)

base_data_wrong_pred_df_good_on_improved = base_data_wrong_pred_df_good_on_improved[
    (((base_data_wrong_pred_df_good_on_improved['true_gender_x'] == 0) & (
                base_data_wrong_pred_df_good_on_improved['gender_pred_x'] >= 0.5)) |
     ((base_data_wrong_pred_df_good_on_improved['true_gender_x'] == 1) & (
                 base_data_wrong_pred_df_good_on_improved['gender_pred_x'] < 0.5)))

    &

    (((base_data_wrong_pred_df_good_on_improved['true_gender_y'] == 0) & (
                base_data_wrong_pred_df_good_on_improved['gender_pred_y'] < 0.5)) |
     ((base_data_wrong_pred_df_good_on_improved['true_gender_y'] == 1) & (
                 base_data_wrong_pred_df_good_on_improved['gender_pred_y'] >= 0.5)))

    ]
# Calculate error magnitude
base_data_wrong_pred_df_good_on_improved['base_error'] = abs(
    base_data_wrong_pred_df_good_on_improved['gender_pred_x'] - base_data_wrong_pred_df_good_on_improved[
        'true_gender_x'])

# Sort by error magnitude (descending) and select top N
N = 5  # Change this to your desired number
top_N_wrong = base_data_wrong_pred_df_good_on_improved.sort_values('base_error', ascending=False).head(N)
improved_image_files = top_N_wrong["image_path"]

top_N_wrong

# Calculate age prediction errors for both models
merged_data_base['age_error'] = abs(merged_data_base['age_pred'] - merged_data_base['true_age'])
merged_data_improved['age_error'] = abs(merged_data_improved['age_pred'] - merged_data_improved['true_age'])

# Merge the datasets
age_comparison = pd.merge(
    merged_data_base[['image_path', 'true_age', 'age_pred', 'age_error']],
    merged_data_improved[['image_path', 'age_pred', 'age_error']],
    on='image_path',
    suffixes=('_base', '_improved')
)

# Calculate error reduction
age_comparison['error_reduction'] = age_comparison['age_error_base'] - age_comparison['age_error_improved']

# Sort by largest improvement and select top N
N = 5  # Change this to your desired number
top_N_age_improved = age_comparison.sort_values('error_reduction', ascending=False).head(N)
improved_age_image_files = top_N_age_improved["image_path"]
top_N_age_improved
Out[112]:
image_path true_age age_pred_base age_error_base age_pred_improved age_error_improved error_reduction
4421 80_1_0_20170110131953974.jpg.chip.jpg 80 30.372637 49.627363 68.726158 11.273842 38.353521
3372 46_1_3_20170120140919993.jpg.chip.jpg 46 12.257863 33.742137 46.783295 0.783295 32.958842
3788 55_0_0_20170117204213768.jpg.chip.jpg 55 17.566719 37.433281 44.354115 10.645885 26.787395
2525 34_1_2_20170108224608753.jpg.chip.jpg 34 6.370270 27.629730 33.015022 0.984978 26.644752
2910 38_1_0_20170117154129371.jpg.chip.jpg 38 14.315865 23.684135 37.683617 0.316383 23.367752
In [ ]:
 
In [113]:
 
In [114]:
 
In [153]:
#### Misclassified Age
base_worst_images_age = [
    'dataset/test_2_folds_last/111_1_0_20170120134646399.jpg.chip.jpg',
    'dataset/test_2_folds_last/9_0_0_20170110225030430.jpg.chip.jpg',
    'dataset/test_2_folds_last/41_1_1_20170117021604893.jpg.chip.jpg',
    'dataset/test_2_folds_last/8_0_1_20170114025855492.jpg.chip.jpg',
    'dataset/test_2_folds_last/80_1_0_20170110131953974.jpg.chip.jpg',
    'dataset/test_2_folds_last/15_0_0_20170116201332456.jpg.chip.jpg',

]

base_worst_images_gender = [
    'dataset/test_2_folds_last/26_1_1_20170116154712959.jpg.chip.jpg',
    'dataset/test_2_folds_last/111_1_0_20170120134646399.jpg.chip.jpg',
    'dataset/test_2_folds_last/9_0_0_20170110225030430.jpg.chip.jpg',
    'dataset/test_2_folds_last/8_0_1_20170114025855492.jpg.chip.jpg'
]

results_gender_worst_base = [
    error_analysis.process_image_for_models(f"{img_file}", [base_model, improved_model]) for img_file in
    base_worst_images_gender]

results_age_worst_base = [
    error_analysis.process_image_for_models(f"{img_file}", [base_model, improved_model]) for img_file in
    base_worst_images_age]
In [154]:
importlib.reload(error_analysis)
error_analysis.display_grid_comparison(results_gender_worst_base, ['Base Model', 'Improved Model'], comparison_type='gender')
No description has been provided for this image
Figure size: 840x1120 px
In [151]:
importlib.reload(error_analysis)
error_analysis.display_grid_comparison(results_age_worst_base, ['Base Model', 'Improved Model'], comparison_type='age')
No description has been provided for this image
Figure size: 840x1680 px
In [ ]:
 
In [ ]:
 
In [119]:
 
In [259]:
merged_data_improved
Out[259]:
gender_pred age_pred true_gender true_age image_path variance unique_colors entropy brisque_score laplacian_variance ... skin_tone age gender age_group age_bin_raw Images entropy_bin brisque_score_bin laplacian_variance_bin fft_blur_score_bin
0 0.999973 90.525475 1 100 100_1_2_20170105174847679.jpg.chip.jpg 1633.451704 8038 7.538403 51.687888 38.042864 ... 10.5084 100 1 60-inf 90-inf 0 62 146 4 4
1 0.999996 94.688004 1 105 105_1_0_20170112213001988.jpg.chip.jpg 1762.530230 8305 7.506086 14.444382 443.091900 ... 22.4492 105 1 60-inf 90-inf 0 56 9 132 142
2 0.000005 13.253082 0 10 10_0_0_20170103233459275.jpg.chip.jpg 3085.937258 7645 7.854686 33.476290 323.148460 ... 0.0000 10 0 0-18 0-10 0 148 80 112 79
3 0.964019 13.910191 0 10 10_0_0_20170110220111082.jpg.chip.jpg 6558.431532 7217 7.520848 24.230622 901.236160 ... 35.3772 10 0 0-18 0-10 0 58 35 154 154
4 0.158536 8.086036 0 10 10_0_0_20170110220447314.jpg.chip.jpg 2803.122973 611 7.060971 6.895933 1242.605226 ... 54.0000 10 0 0-18 0-10 0 8 2 158 150
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
4610 0.998781 7.660334 1 9 9_1_0_20170109202824646.jpg.chip.jpg 2106.611190 8666 7.609472 14.849944 731.926092 ... 31.7656 9 1 0-18 0-10 0 81 10 150 133
4611 0.999835 13.670305 1 9 9_1_0_20170109203410981.jpg.chip.jpg 2842.062584 8257 7.569806 15.243827 487.166810 ... 26.5600 9 1 0-18 0-10 0 69 11 137 134
4612 0.999962 15.859377 1 9 9_1_0_20170109204249427.jpg.chip.jpg 1838.362976 9038 7.649116 20.562037 634.092552 ... 14.4340 9 1 0-18 0-10 0 93 22 147 138
4613 0.012970 7.215901 1 9 9_1_0_20170109204626343.jpg.chip.jpg 3243.390748 5083 7.412114 26.172752 268.870751 ... 26.9032 9 1 0-18 0-10 0 37 43 99 123
4614 0.430373 5.931127 1 9 9_1_2_20161219190524395.jpg.chip.jpg 2400.697659 6844 7.497291 39.660156 217.529045 ... 19.1000 9 1 0-18 0-10 0 54 111 83 88

4615 rows × 23 columns

Of course, we have specifically selected the best-case examples (i.e. where the model's performance improved the most), which probably gives a much too optimistic picture of the overall improvement (relative to the overall increase in accuracy/MAE, which is not as significant).

Instead, we've selected some of the samples our initial model failed on that were unlikely to be mislabeled:

In [272]:
importlib.reload(error_analysis)
error_analysis.evaluate_age_prediction(merged_data_improved["true_age"], merged_data_improved["age_pred"], bins=metrics.DEFAULT_AGE_BINS)
No description has been provided for this image